EXPLAINABLE MACHINE LEARNING¶
@Author Gabriel Schurr, Ilyesse Hettenbach
IMPORTS¶
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from dotenv import load_dotenv
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from tqdm.notebook import tqdm
from torch_utils import get_loaders, transform
# Load variables from a local .env file (python-dotenv) so that os.getenv can see DATAPATH below.
load_dotenv()
True
# --- Run configuration -------------------------------------------------------
# Hyperparameters and loader settings used by the cells further down.
BATCH_SIZE = 256
EPOCHS = 3
LR = 0.0007
RANDOM_SEED = 42
NUM_WORKERS = 0

# Dataset root comes from the DATAPATH environment variable.
# NOTE: str() turns a missing variable into the literal string "None", so an
# unset DATAPATH only surfaces later as a file-not-found error.
DATAPATH = str(os.getenv("DATAPATH"))
# Label list lives next to the image folder, directly under the dataset root.
P_LABELS = os.path.join(DATAPATH, "images_labels.txt")
# From here on DATAPATH points at the image folder itself.
DATAPATH = os.path.join(DATAPATH, "animals")
print(f'Path to images: {DATAPATH}')
print(f'Path to labels: {P_LABELS}')
# Manual override example for a local Windows checkout:
# P_LABELS = "D:\\Database\\animals\\original\\images_labels.txt"
# DATAPATH = "D:\\Database\\animals\\original\\animals"

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# DEVICE = 'cpu'
print(f"Device: {DEVICE}")
Path to images: W:/Workspaces/Python/Python-Playground/src/Studium/ExplainableMLTrain/src/animal-image-dataset-90-different-animals/animals\animals Path to labels: W:/Workspaces/Python/Python-Playground/src/Studium/ExplainableMLTrain/src/animal-image-dataset-90-different-animals/animals\images_labels.txt Device: cuda
# Switches controlling which stages of the notebook run.
# do_train: run the training loop guarded by `if do_train:`.
do_train = False
# do_checkpoint: save the model to `model_path` via torch.save.
do_checkpoint = False
# do_load_model: load the model from `model_path` when that file exists.
do_load_model = True
# File used both for saving and for loading the model.
model_path = "best_model.pth"
EDA¶
# Read the label list: one "<image_path> <label>" record per line.
data = []
with open(P_LABELS, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        # Split on the LAST space only, so image paths that themselves
        # contain spaces are kept intact.
        image_path, label = line.rsplit(' ', 1)
        data.append({'image_path': image_path, 'label': label})
# Fix the column order so the frame has the expected schema even when empty.
animal_df = pd.DataFrame(data, columns=['image_path', 'label'])
animal_df.head()
| image_path | label | |
|---|---|---|
| 0 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope |
| 1 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope |
| 2 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope |
| 3 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope |
| 4 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope |
# Summary statistics (count / unique / top / freq) for the path and label columns.
animal_df.describe()
| image_path | label | |
|---|---|---|
| count | 5375 | 5375 |
| unique | 5375 | 90 |
| top | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope |
| freq | 1 | 60 |
# Show a 2x3 grid of randomly chosen images with their labels.
fig, axes = plt.subplots(2, 3, figsize=(10, 6))
for ax in axes.flat:
    random_index = random.randint(0, len(animal_df) - 1)
    sample_row = animal_df.iloc[random_index]
    # The context manager closes the underlying file once the pixels have
    # been handed to matplotlib, instead of leaving the handle open.
    with Image.open(sample_row['image_path']) as img:
        ax.imshow(img)
    ax.set_title(sample_row['label'])
    ax.axis('off')
plt.tight_layout()
plt.show()
# Visualize class distribution: one bar per label, counting rows of animal_df.
plt.figure(figsize=(20, 6))
sns.countplot(x='label', data=animal_df)
plt.title('Class Distribution')
# Rotate the label names so they do not overlap on the x axis.
plt.xticks(rotation=90)
plt.show()
MODEL¶
class CustomResNet18(nn.Module):
    """ResNet-18 initialised with ImageNet weights and a new linear classifier.

    The final fully connected layer of the torchvision backbone is replaced
    by a single ``nn.Linear`` (wrapped in ``nn.Sequential``) that produces
    ``num_classes`` outputs.
    """

    def __init__(self, num_classes=90):
        super().__init__()
        # Result of the most recent forward pass; refreshed by forward().
        self.output = None
        backbone = models.resnet18(weights='IMAGENET1K_V1')
        # Swap the 1000-way ImageNet classifier for our own head.
        head = nn.Linear(backbone.fc.in_features, num_classes)
        backbone.fc = nn.Sequential(head)
        self.resnet = backbone

    def forward(self, x):
        """Run the backbone on ``x``, remember the result and return it."""
        scores = self.resnet(x)
        self.output = scores
        return scores

    def freeze_backbone(self):
        """Make only the classifier head trainable."""
        self.resnet.requires_grad_(False)
        self.resnet.fc.requires_grad_(True)

    def unfreeze_backbone(self):
        """Make every parameter of the network trainable again."""
        self.resnet.requires_grad_(True)
# Either restore a previously saved model or start from the pretrained backbone.
if do_load_model and os.path.exists(model_path):
    # SECURITY: torch.load unpickles arbitrary Python objects - only load
    # checkpoint files you created yourself or otherwise trust.
    # map_location keeps the restored model on the CPU, which the summary
    # call below requires, regardless of the device it was saved from.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects whole-module pickles like this one - confirm against the
    # installed version and pass weights_only=False there if needed.
    model = torch.load(model_path, map_location="cpu")
    print("Model loaded")
else:
    # Only build (and download weights for) a fresh network when nothing is
    # loaded; previously it was always built and then thrown away.
    model = CustomResNet18()
summary(model, (3, 224, 224), device="cpu")
Model loaded
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 112, 112] 9,408
BatchNorm2d-2 [-1, 64, 112, 112] 128
ReLU-3 [-1, 64, 112, 112] 0
MaxPool2d-4 [-1, 64, 56, 56] 0
Conv2d-5 [-1, 64, 56, 56] 36,864
BatchNorm2d-6 [-1, 64, 56, 56] 128
ReLU-7 [-1, 64, 56, 56] 0
Conv2d-8 [-1, 64, 56, 56] 36,864
BatchNorm2d-9 [-1, 64, 56, 56] 128
ReLU-10 [-1, 64, 56, 56] 0
BasicBlock-11 [-1, 64, 56, 56] 0
Conv2d-12 [-1, 64, 56, 56] 36,864
BatchNorm2d-13 [-1, 64, 56, 56] 128
ReLU-14 [-1, 64, 56, 56] 0
Conv2d-15 [-1, 64, 56, 56] 36,864
BatchNorm2d-16 [-1, 64, 56, 56] 128
ReLU-17 [-1, 64, 56, 56] 0
BasicBlock-18 [-1, 64, 56, 56] 0
Conv2d-19 [-1, 128, 28, 28] 73,728
BatchNorm2d-20 [-1, 128, 28, 28] 256
ReLU-21 [-1, 128, 28, 28] 0
Conv2d-22 [-1, 128, 28, 28] 147,456
BatchNorm2d-23 [-1, 128, 28, 28] 256
Conv2d-24 [-1, 128, 28, 28] 8,192
BatchNorm2d-25 [-1, 128, 28, 28] 256
ReLU-26 [-1, 128, 28, 28] 0
BasicBlock-27 [-1, 128, 28, 28] 0
Conv2d-28 [-1, 128, 28, 28] 147,456
BatchNorm2d-29 [-1, 128, 28, 28] 256
ReLU-30 [-1, 128, 28, 28] 0
Conv2d-31 [-1, 128, 28, 28] 147,456
BatchNorm2d-32 [-1, 128, 28, 28] 256
ReLU-33 [-1, 128, 28, 28] 0
BasicBlock-34 [-1, 128, 28, 28] 0
Conv2d-35 [-1, 256, 14, 14] 294,912
BatchNorm2d-36 [-1, 256, 14, 14] 512
ReLU-37 [-1, 256, 14, 14] 0
Conv2d-38 [-1, 256, 14, 14] 589,824
BatchNorm2d-39 [-1, 256, 14, 14] 512
Conv2d-40 [-1, 256, 14, 14] 32,768
BatchNorm2d-41 [-1, 256, 14, 14] 512
ReLU-42 [-1, 256, 14, 14] 0
BasicBlock-43 [-1, 256, 14, 14] 0
Conv2d-44 [-1, 256, 14, 14] 589,824
BatchNorm2d-45 [-1, 256, 14, 14] 512
ReLU-46 [-1, 256, 14, 14] 0
Conv2d-47 [-1, 256, 14, 14] 589,824
BatchNorm2d-48 [-1, 256, 14, 14] 512
ReLU-49 [-1, 256, 14, 14] 0
BasicBlock-50 [-1, 256, 14, 14] 0
Conv2d-51 [-1, 512, 7, 7] 1,179,648
BatchNorm2d-52 [-1, 512, 7, 7] 1,024
ReLU-53 [-1, 512, 7, 7] 0
Conv2d-54 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-55 [-1, 512, 7, 7] 1,024
Conv2d-56 [-1, 512, 7, 7] 131,072
BatchNorm2d-57 [-1, 512, 7, 7] 1,024
ReLU-58 [-1, 512, 7, 7] 0
BasicBlock-59 [-1, 512, 7, 7] 0
Conv2d-60 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-61 [-1, 512, 7, 7] 1,024
ReLU-62 [-1, 512, 7, 7] 0
Conv2d-63 [-1, 512, 7, 7] 2,359,296
BatchNorm2d-64 [-1, 512, 7, 7] 1,024
ReLU-65 [-1, 512, 7, 7] 0
BasicBlock-66 [-1, 512, 7, 7] 0
AdaptiveAvgPool2d-67 [-1, 512, 1, 1] 0
Linear-68 [-1, 90] 46,170
ResNet-69 [-1, 90] 0
================================================================
Total params: 11,222,682
Trainable params: 11,222,682
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 62.79
Params size (MB): 42.81
Estimated Total Size (MB): 106.17
----------------------------------------------------------------
# Encode every label as an integer class id (pandas category codes).
animal_df["class"] = animal_df["label"].astype("category").cat.codes

# Lookup tables between class ids and animal names, in order of first appearance.
code_label_pairs = animal_df[['class', 'label']].drop_duplicates()
class_to_animal = dict(zip(code_label_pairs['class'].tolist(), code_label_pairs['label'].tolist()))
animal_to_class = dict(zip(class_to_animal.values(), class_to_animal.keys()))
animal_df.head()
| image_path | label | class | |
|---|---|---|---|
| 0 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope | 0 |
| 1 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope | 0 |
| 2 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope | 0 |
| 3 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope | 0 |
| 4 | W:/Workspaces/Python/Python-Playground/src/Stu... | antelope | 0 |
# Random 80/20 split. The seed comes from the shared RANDOM_SEED constant
# instead of a second hard-coded 42, so there is a single place to change it.
train_df = animal_df.sample(frac=0.8, random_state=RANDOM_SEED)
# Everything that was not sampled for training becomes the test set.
test_df = animal_df.drop(train_df.index)
# Re-number both frames from 0; the evaluation cell later joins its
# predictions to test_df on this index.
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)
# Side-by-side label histograms for the two splits (tick labels hidden).
fig, axes = plt.subplots(1, 2, figsize=(15, 4))
split_panels = (
    (axes[0], train_df, 'Train Class Distribution'),
    (axes[1], test_df, 'Test Class Distribution'),
)
for panel_ax, split_df, panel_title in split_panels:
    sns.countplot(x='label', data=split_df, ax=panel_ax)
    panel_ax.set_xticks([])
    panel_ax.set_title(panel_title)
plt.tight_layout()
plt.show()
# Build the data loaders with the project helper get_loaders (from torch_utils);
# test_df is passed as the second frame and serves as the validation split.
train_loader, val_loader = get_loaders(train_df, test_df, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# Cross-entropy loss on the model outputs.
criterion = nn.CrossEntropyLoss()
# AdamW over every parameter of the model, with learning rate LR.
optimizer = optim.AdamW(model.parameters(), lr=LR)
TRAINING¶
if do_train:
    model.train().to(DEVICE)
    # Per-batch history, plotted after training.
    train_losses, train_accs = [], []
    n_batches = len(train_loader)

    # The progress bar advances once per epoch; the postfix shows batch progress.
    with tqdm(total=EPOCHS, desc='Training') as pbar:
        for epoch in range(1, EPOCHS + 1):
            running_loss = 0.0
            for i, (idx, inputs, labels) in enumerate(train_loader):
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                running_loss += batch_loss
                train_losses.append(batch_loss)
                # Fraction of samples in this batch whose top-1 class is correct.
                batch_acc = (outputs.argmax(1) == labels).float().mean().item()
                train_accs.append(batch_acc)
                pbar.set_postfix({'batch': f'{i+1}/{n_batches}', 'loss': f'{running_loss/(i+1):.3f}'})
            pbar.update(1)

    # Training curves, one point per batch.
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    loss_ax, acc_ax = axes
    loss_ax.plot(train_losses)
    loss_ax.set_title('Train Loss')
    acc_ax.plot(train_accs)
    acc_ax.set_title('Train Accuracy')
    plt.show()

    # NOTE(review): the checkpoint is assumed to belong to the training
    # branch - confirm against the original notebook cell.
    if do_checkpoint:
        # Move to CPU and switch to eval mode before pickling the whole module.
        model.cpu().eval()
        torch.save(model, model_path)
EVALUATION¶
# Run the model over the validation loader and record one row per sample:
# its dataset index, the predicted class id and whether that was correct.
model.eval().to(DEVICE)
pred_data = []
with tqdm(total=len(val_loader), desc='Validation') as pbar, torch.no_grad():
    for idx, inputs, labels in val_loader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

        # Move results to the host once per batch instead of calling
        # .item() three times per sample (each call is a device sync).
        batch_idx = [int(sample_idx) for sample_idx in idx]
        batch_pred = predicted.cpu().tolist()
        batch_correct = (predicted == labels).cpu().tolist()
        pred_data.extend(
            {'idx': sample_idx, 'pred': pred, 'correct': correct}
            for sample_idx, pred, correct in zip(batch_idx, batch_pred, batch_correct)
        )
        pbar.update(1)

pred_df = pd.DataFrame(pred_data).set_index('idx')
pred_df['pred_animal'] = pred_df['pred'].map(class_to_animal)
# Join predictions with the test metadata on the sample index.
# NOTE(review): assumes the loader's idx values are row positions of test_df - confirm in torch_utils.
pred_df = pd.merge(pred_df, test_df, left_index=True, right_index=True)
Validation: 0%| | 0/5 [00:00<?, ?it/s]
# Plot accuracy by class: the mean of the boolean 'correct' column per label.
acc_by_class = pred_df.groupby('label')['correct'].mean().reset_index()
plt.figure(figsize=(20, 6))
sns.barplot(x='label', y='correct', data=acc_by_class)
plt.title('Accuracy by Class')
plt.xticks(rotation=90)
plt.show()

# Number of worst classes to list; the heading is derived from it so the
# text can no longer disagree with the table (it said 5 while showing 10).
n_lowest = 10
print(f'Mean Accuracy: {pred_df["correct"].mean():.3f} on validation set')
print(f'{n_lowest} Lowest Accuracy Classes:')
display(acc_by_class.sort_values('correct').head(n_lowest).set_index('label'))
Mean Accuracy: 0.842 on validation set 5 Lowest Accuracy Classes:
| correct | |
|---|---|
| label | |
| dog | 0.454545 |
| rat | 0.500000 |
| ladybugs | 0.500000 |
| bat | 0.500000 |
| donkey | 0.571429 |
| squid | 0.600000 |
| squirrel | 0.625000 |
| mouse | 0.625000 |
| butterfly | 0.642857 |
| ox | 0.666667 |
GRAD-CAM¶
The following code generates Grad-CAM visualizations for the model. Grad-CAM shows which areas of an image the model focuses on when making a prediction: the gradients of the target class score with respect to the feature maps of a late convolutional layer are used to weight those feature maps, which yields a coarse heatmap. The heatmap is then superimposed on the original image to show the areas the model is focusing on.
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
def plot_grad_cam(model, img_path, transform, image_weight=0.5, device='cpu', targets=None):
    """Predict the class of one image and overlay its Grad-CAM heatmap.

    Args:
        model: Network exposing ``resnet.layer4`` and ``unfreeze_backbone()``
            (e.g. ``CustomResNet18``). It is moved to ``device`` and all of
            its parameters are made trainable so gradients can flow.
        img_path: Path of the image file to explain.
        transform: Callable turning a PIL image into the model's input tensor.
        image_weight: Blend factor passed to ``show_cam_on_image``.
        device: Device the model and the input are moved to.
        targets: Class index to compute the heatmap for. ``None`` leaves the
            choice to GradCAM (per the library docs, the top-scoring class).

    Returns:
        Tuple ``(visualization, pred_animal, pred_prob)``: the overlay as a
        PIL image at the original resolution, the predicted animal name
        (looked up in the module-level ``class_to_animal``) and its softmax
        probability.
    """
    target_layers = model.resnet.layer4
    model = model.to(device)
    model.unfreeze_backbone()
    use_cuda = 'cuda' in str(device)

    # Close the file handle as soon as the RGB copy has been made.
    with Image.open(img_path) as raw_image:
        original_image = raw_image.convert('RGB')
    input_tensor = transform(original_image).unsqueeze(0).to(device)

    # The plain prediction needs no autograd graph.
    with torch.no_grad():
        pred = model(input_tensor)
    pred_class_idx = pred.argmax(1).item()
    pred_animal = class_to_animal[pred_class_idx]
    pred_prob = pred.softmax(1)[0, pred_class_idx].item()

    cam_targets = [ClassifierOutputTarget(targets)] if targets is not None else None
    # Using GradCAM as a context manager releases its hooks on exit, so
    # repeated calls do not pile up hooks on the model.
    with GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=cam_targets)
    grayscale_cam = grayscale_cam[0, :]
    # Scale the heatmap back up to the size of the original image (width, height).
    grayscale_cam = cv2.resize(grayscale_cam, (original_image.size[0], original_image.size[1]))

    # Min-max normalise the image to float32 in [0, 1], as show_cam_on_image
    # expects. The previous code discarded the shifted image and divided the
    # unshifted one by the shifted maximum, which exceeds 1 whenever the
    # darkest pixel is not 0.
    np_image = np.asarray(original_image, dtype=np.float32)
    np_image = np_image - np_image.min()
    peak = np_image.max()
    if peak > 0:  # a constant image stays all zeros instead of dividing by 0
        np_image = np_image / peak

    visualization = show_cam_on_image(np_image, grayscale_cam, use_rgb=True, image_weight=image_weight)
    visualization = Image.fromarray(visualization)
    return visualization, pred_animal, pred_prob
True Predictions
In the following we will show a few examples of correctly classified images and their corresponding Grad-CAM heatmaps.
# Grad-CAM overlays for a random sample of correctly classified images.
images_to_show = 15
images_per_row = 5

# 'correct' is boolean, so it can be used directly as a mask.
candidates = pred_df[pred_df['correct']]
# Never ask for more rows than exist; .sample would raise otherwise.
samples = candidates.sample(min(images_to_show, len(candidates)), random_state=38)

# Round up so a partially filled last row is still shown (at least one row).
rows = max(1, -(-len(samples) // images_per_row))
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(rows, images_per_row, figsize=(20, 4*rows), squeeze=False)
for ax, (_, row) in zip(axes.flat, samples.iterrows()):
    img_path = row['image_path']
    true_animal = row['label']
    visualization, pred_animal, pred_prob = plot_grad_cam(model, img_path, transform, device=DEVICE, image_weight=0.6)
    visualization = visualization.resize((224, 224))
    ax.imshow(visualization)
    ax.set_title(f'True Animal: {true_animal}\nPredicted Animal: {pred_animal} ({pred_prob*100:.0f}%)')
# Hide the frames of every panel, including unused ones.
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout()
plt.show()
False Predictions
In the following we will take a look at a few false predictions of our model. To explain the false prediction, we will use the Grad-CAM algorithm. With that approach we can visualize the regions of the image that were most important for the model to make its prediction.
# Grad-CAM overlays for a random sample of misclassified images
# (heatmap computed for GradCAM's default target).
images_to_show = 15
images_per_row = 5

# 'correct' is boolean, so its negation selects the wrong predictions.
candidates = pred_df[~pred_df['correct']]
# Never ask for more rows than exist; .sample would raise otherwise.
samples = candidates.sample(min(images_to_show, len(candidates)), random_state=38)

# Round up so a partially filled last row is still shown (at least one row).
rows = max(1, -(-len(samples) // images_per_row))
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(rows, images_per_row, figsize=(20, 4*rows), squeeze=False)
for ax, (_, row) in zip(axes.flat, samples.iterrows()):
    img_path = row['image_path']
    true_animal = row['label']
    visualization, pred_animal, pred_prob = plot_grad_cam(model, img_path, transform, device=DEVICE, image_weight=0.6)
    visualization = visualization.resize((224, 224))
    ax.imshow(visualization)
    ax.set_title(f'True Animal: {true_animal}\nPredicted Animal: {pred_animal} ({pred_prob*100:.0f}%)')
# Hide the frames of every panel, including unused ones.
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout()
plt.show()
False Predictions but true class given
In the following we show the same misclassified examples again, but this time the gradients are calculated for the true class instead of the class chosen by Grad-CAM by default. This is done to check whether the model is actually looking at the right things but is just not able to make the correct prediction. Comparing these heatmaps with the ones in the previous section makes the difference between the evidence for the true class and for the predicted class visible.
# Grad-CAM overlays for misclassified images, with the heatmap computed for
# the TRUE class rather than GradCAM's default target.
images_to_show = 15
images_per_row = 5

# 'correct' is boolean, so its negation selects the wrong predictions.
candidates = pred_df[~pred_df['correct']]
# Never ask for more rows than exist; .sample would raise otherwise.
samples = candidates.sample(min(images_to_show, len(candidates)), random_state=38)

# Round up so a partially filled last row is still shown (at least one row).
rows = max(1, -(-len(samples) // images_per_row))
# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(rows, images_per_row, figsize=(20, 4*rows), squeeze=False)
for ax, (_, row) in zip(axes.flat, samples.iterrows()):
    img_path = row['image_path']
    true_animal = row['label']
    # Plain Python int, so the target index does not depend on the numpy
    # integer dtype pandas chose for the 'class' column.
    true_class = int(row['class'])
    visualization, pred_animal, pred_prob = plot_grad_cam(model, img_path, transform, device=DEVICE, targets=true_class, image_weight=0.6)
    visualization = visualization.resize((224, 224))
    ax.imshow(visualization)
    ax.set_title(f'True Animal: {true_animal}\nPredicted Animal: {pred_animal} ({pred_prob*100:.0f}%)')
# Hide the frames of every panel, including unused ones.
for ax in axes.flat:
    ax.axis('off')
plt.tight_layout()
plt.show()